"""
Phase 9: Paper Figures
========================
Generates four figures for the paper:
  F1: Forest plot — Z₁ coefficient across all DVs with 95% CI
  F2: KAOPEN gating — Z₁ on gross_ifi vs ca_gdp by KAOPEN tercile
  F3: Bilateral vs multilateral coefficient comparison (paired bars)
  F4: Mediation path diagram with coefficients
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Project layout: this script sits in a subdirectory of PROJECT_DIR, and
# ROOT_DIR is the directory that contains the project itself.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral" project's src/ first on the import path so its
# PanelGLS estimator can be reused here (import-time side effect).
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output directories.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"  # not used by the figures in this script
OUT_FIGURES = PROJECT_DIR / "output" / "figures"
# Created at import time so every figure can be saved without further checks.
OUT_FIGURES.mkdir(parents=True, exist_ok=True)

# Demographic regressors; every figure reports the Z_1 coefficient.
DEMO_VARS = ['Z_1', 'Z_2', 'Z_3']
# Candidate controls; main() keeps only those present with >200 non-missing values.
EBA_CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Map a p-value to its conventional significance marker.

    Returns '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1 and an
    empty string otherwise (including NaN, since every comparison is False).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_gls(df, y_var, x_vars, label='', quiet=True, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and flatten the results.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped before fitting.

    Args:
        df: Panel DataFrame containing ``y_var``, ``x_vars``, ``iso3`` and ``year``.
        y_var: Name of the dependent-variable column.
        x_vars: Sequence of regressor column names (list or tuple).
        label: Model label stored in the result and used in the diagnostic line.
        quiet: When False, print a one-line summary of the Z_1 estimate.
        min_obs: Minimum number of complete rows required to fit (default 50).

    Returns:
        A dict with ``model``, ``dep_var``, ``n_obs``, ``r_squared`` and, for
        every regressor ``name``, the keys ``{name}_coef``, ``{name}_se`` and
        ``{name}_p``. Returns None when fewer than ``min_obs`` complete rows
        remain.
    """
    # Normalize so tuples work both for list concatenation and column selection
    # (a tuple inside df[...] would be read as a single column key).
    regressors = list(x_vars)
    cols = [y_var] + regressors + ['iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label, 'dep_var': y_var,
        'n_obs': gls.n_obs, 'r_squared': gls.r_squared,
    }
    # The coefficient arrays follow the column order of `regressors`.
    for i, name in enumerate(regressors):
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    if not quiet:
        # Look Z_1 up by name rather than assuming it is the first regressor.
        if 'Z_1' in regressors:
            k = regressors.index('Z_1')
            print(f"  {label}: Z₁={gls.beta[k]:.4f} (p={gls.pvalues[k]:.4f}), N={gls.n_obs}")
        elif regressors:
            print(f"  {label}: {regressors[0]}={gls.beta[0]:.4f} "
                  f"(p={gls.pvalues[0]:.4f}), N={gls.n_obs}")

    return result


def _plot_forest(df, base_vars):
    """Draw Figure 1: Z_1 coefficient with 95% CI for every net and gross DV.

    DVs that are absent from ``df`` or have 200 or fewer non-missing values
    are skipped. Net DVs are drawn in blue and gross DVs in red. Writes
    ``forest_plot_z1.png`` to OUT_FIGURES when at least one DV is estimable.
    """
    print("\n[F1] Forest plot ...")

    net_dvs = ['ca_gdp', 'trade_balance_gdp', 'income_balance_gdp',
               'savings_investment_gap']
    gross_dvs = ['gross_assets_gdp', 'gross_liab_gdp', 'gross_ifi', 'nfa_gdp',
                 'debt_assets_gdp', 'fdi_assets_gdp', 'port_eq_assets_gdp',
                 'fx_reserves_gdp']
    all_dvs = [v for v in net_dvs + gross_dvs
               if v in df.columns and df[v].notna().sum() > 200]

    labels, coefs, ci_lo, ci_hi, colors = [], [], [], [], []
    for dv in all_dvs:
        r = run_gls(df, dv, base_vars, dv)
        if not r or 'Z_1_coef' not in r:
            continue
        coef, se = r['Z_1_coef'], r['Z_1_se']
        labels.append(dv.replace('_gdp', '').replace('balance_', '')
                      .replace('savings_investment_', 'S-I '))
        coefs.append(coef)
        ci_lo.append(coef - 1.96 * se)
        ci_hi.append(coef + 1.96 * se)
        colors.append('#2166ac' if dv in net_dvs else '#b2182b')

    if not labels:
        return

    from matplotlib.lines import Line2D

    fig, ax = plt.subplots(figsize=(8, max(4, len(labels) * 0.4)))
    y_pos = np.arange(len(labels))
    for y, coef, lo, hi, color in zip(y_pos, coefs, ci_lo, ci_hi, colors):
        ax.plot([lo, hi], [y, y], color=color, linewidth=2, alpha=0.7)
        ax.plot(coef, y, 'o', color=color, markersize=8)

    ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Z₁ Coefficient (95% CI)')
    ax.set_title('Figure 1: Z₁ Effect on Net vs Gross External Positions')
    ax.invert_yaxis()

    legend_elements = [
        Line2D([0], [0], color='#2166ac', marker='o', label='Net (CA/trade/income)'),
        Line2D([0], [0], color='#b2182b', marker='o', label='Gross (assets/liab/IIP)'),
    ]
    ax.legend(handles=legend_elements, loc='lower right')

    plt.tight_layout()
    fig.savefig(OUT_FIGURES / "forest_plot_z1.png", dpi=150)
    plt.close(fig)
    print("  Saved: forest_plot_z1.png")


def _plot_kaopen_gating(df, controls):
    """Draw Figure 2: Z_1 on gross_ifi vs ca_gdp, estimated within each KAOPEN tercile.

    KAOPEN is excluded from the controls because the sample is already split
    on it. Terciles whose regression fails are left empty rather than
    shifting the remaining bars. Writes ``kaopen_gating.png`` when both DVs
    have at least one estimable tercile.
    """
    print("\n[F2] KAOPEN gating plot ...")
    if 'kaopen' not in df.columns:
        return

    tercile_names = ['Closed', 'Mid', 'Open']
    kaopen = df['kaopen'].dropna()
    if kaopen.empty:
        return
    # Bin to integer codes first. With duplicates='drop', tied quantiles can
    # collapse an edge and leave fewer than three bins; passing three labels
    # to qcut would then raise ValueError instead of reporting the problem.
    codes = pd.qcut(kaopen, 3, labels=False, duplicates='drop')
    n_bins = codes.nunique()
    if n_bins != 3:
        print(f"  Skipped: KAOPEN yields only {n_bins} distinct tercile bin(s)")
        return
    terciles = codes.map(dict(enumerate(tercile_names))).reindex(df.index)

    dvs = ['gross_ifi', 'ca_gdp']
    demo_only = DEMO_VARS + [c for c in controls if c != 'kaopen']

    # kaopen_results[dv][tercile] -> {'coef', 'se', 'p'}
    kaopen_results = {dv: {} for dv in dvs}
    for tercile in tercile_names:
        sub = df[terciles == tercile].copy()
        for dv in dvs:
            if dv not in sub.columns:
                continue
            r = run_gls(sub, dv, demo_only, f'{tercile}: {dv}')
            if r and 'Z_1_coef' in r:
                kaopen_results[dv][tercile] = {
                    'coef': r['Z_1_coef'],
                    'se': r['Z_1_se'],
                    'p': r['Z_1_p'],
                }

    if not (kaopen_results['gross_ifi'] and kaopen_results['ca_gdp']):
        return

    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(len(tercile_names))
    width = 0.35

    star_specs = []  # (x position, coef, half-CI, marker)
    series_specs = [
        ('gross_ifi', '#b2182b', 'Gross IFI'),
        ('ca_gdp', '#2166ac', 'CA/GDP'),
    ]
    for i, (dv, color, label) in enumerate(series_specs):
        res = kaopen_results[dv]
        # Place each bar at its own tercile slot so a missing tercile leaves a gap.
        present = [k for k, name in enumerate(tercile_names) if name in res]
        offset = -width / 2 + i * width
        xs = x[present] + offset
        vals = [res[tercile_names[k]]['coef'] for k in present]
        errs = [1.96 * res[tercile_names[k]]['se'] for k in present]
        ax.bar(xs, vals, width, yerr=errs,
               label=label, color=color, alpha=0.8, capsize=4)
        for xk, val, err, k in zip(xs, vals, errs, present):
            sig = stars(res[tercile_names[k]]['p'])
            if sig:
                star_specs.append((xk, val, err, sig))

    # Leave headroom for the markers and offset them relative to the axis span
    # so placement works whatever the scale of the coefficients.
    ax.margins(y=0.15)
    y_lo, y_hi = ax.get_ylim()
    pad = 0.02 * (y_hi - y_lo)
    for xk, val, err, sig in star_specs:
        if val >= 0:
            ax.text(xk, val + err + pad, sig, ha='center', va='bottom', fontsize=10)
        else:
            ax.text(xk, val - err - pad, sig, ha='center', va='top', fontsize=10)

    ax.set_xticks(x)
    ax.set_xticklabels(tercile_names)
    ax.set_xlabel('KAOPEN Tercile')
    ax.set_ylabel('Z₁ Coefficient')
    ax.set_title('Figure 2: KAOPEN Gating — Z₁ on Gross IFI vs CA/GDP')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax.legend()

    plt.tight_layout()
    fig.savefig(OUT_FIGURES / "kaopen_gating.png", dpi=150)
    plt.close(fig)
    print("  Saved: kaopen_gating.png")


def _plot_bilateral_vs_multilateral(df, base_vars):
    """Draw Figure 3: paired bars of Z_1 on bilateral-aggregated vs IIP series.

    Each (bilateral, IIP) pair is estimated on the rows where both series are
    observed, so the two bars of a pair share one estimation sample. Writes
    ``bilateral_vs_multilateral.png`` when at least one pair is estimable.
    """
    print("\n[F3] Bilateral vs multilateral comparison ...")

    candidate_pairs = [
        ('agg_portfolio_total_gdp', 'gross_assets_gdp'),
        ('agg_portfolio_debt_gdp', 'debt_assets_gdp'),
    ]
    # Filter pairs as units: filtering each side separately and zipping would
    # mis-pair concepts when only one column of a pair is missing.
    pairs = [(b, m) for b, m in candidate_pairs
             if b in df.columns and m in df.columns]
    if not pairs:
        return

    bil_coefs, iip_coefs, bil_errs, iip_errs, pair_labels = [], [], [], [], []
    for bil_v, iip_v in pairs:
        # Intersection sample for this pair.
        df_int = df[df[bil_v].notna() & df[iip_v].notna()].copy()
        r_bil = run_gls(df_int, bil_v, base_vars, f'Bilat: {bil_v}')
        r_iip = run_gls(df_int, iip_v, base_vars, f'IIP: {iip_v}')
        if not (r_bil and r_iip and 'Z_1_coef' in r_bil and 'Z_1_coef' in r_iip):
            continue
        pair_labels.append(bil_v.replace('agg_portfolio_', '').replace('_gdp', ''))
        bil_coefs.append(r_bil['Z_1_coef'])
        iip_coefs.append(r_iip['Z_1_coef'])
        bil_errs.append(1.96 * r_bil['Z_1_se'])
        iip_errs.append(1.96 * r_iip['Z_1_se'])

    if not pair_labels:
        return

    fig, ax = plt.subplots(figsize=(7, 5))
    x = np.arange(len(pair_labels))
    width = 0.35

    ax.bar(x - width/2, bil_coefs, width, yerr=bil_errs,
           label='Bilateral Aggregated', color='#ef8a62', capsize=4)
    ax.bar(x + width/2, iip_coefs, width, yerr=iip_errs,
           label='Multilateral IIP', color='#67a9cf', capsize=4)

    ax.set_xticks(x)
    ax.set_xticklabels(pair_labels)
    ax.set_ylabel('Z₁ Coefficient')
    ax.set_title('Figure 3: Bilateral vs Multilateral Z₁ Coefficients')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax.legend()

    plt.tight_layout()
    fig.savefig(OUT_FIGURES / "bilateral_vs_multilateral.png", dpi=150)
    plt.close(fig)
    print("  Saved: bilateral_vs_multilateral.png")


def _plot_mediation(df, base_vars):
    """Draw Figure 4: path diagram of Z_1 effects on each external-position variable.

    Each arrow is labelled with the Z_1 coefficient and significance stars.
    Labels are coloured dark green for p < 0.05, orange for p < 0.1 and red
    otherwise. Paths whose DV is missing or not estimable are omitted.
    Writes ``mediation_diagram.png`` when at least one path is estimable.
    """
    print("\n[F4] Mediation path diagram ...")

    # (path key, dependent variable, run_gls label)
    path_specs = [
        ('Z→CA', 'ca_gdp', 'Z→CA'),
        ('Z→S-I', 'savings_investment_gap', 'Z→SI'),
        ('Z→Gross IFI', 'gross_ifi', 'Z→IFI'),
        ('Z→Debt Assets', 'debt_assets_gdp', 'Z→Debt'),
        ('Z→Income Bal', 'income_balance_gdp', 'Z→Income'),
    ]
    paths = {}
    for key, dv, label in path_specs:
        # Check every DV uniformly; indexing a missing column raises KeyError.
        if dv not in df.columns:
            continue
        r = run_gls(df, dv, base_vars, label)
        if r and 'Z_1_coef' in r:
            paths[key] = (r['Z_1_coef'], r['Z_1_p'])

    if not paths:
        return

    fig, ax = plt.subplots(figsize=(10, 7))
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 8)
    ax.axis('off')

    boxes = {
        'Z₁': (1, 4), 'S-I Gap': (5, 6.5), 'Gross IFI': (5, 4),
        'Debt Assets': (5, 1.5), 'CA/GDP': (9, 5.5), 'Income Bal': (9, 2.5),
    }
    for label, (bx, by) in boxes.items():
        bbox = dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.8)
        ax.text(bx, by, label, fontsize=11, ha='center', va='center', bbox=bbox)

    # (path key, arrow tail, arrow head) in data coordinates
    arrow_specs = [
        ('Z→S-I', (1.8, 4.3), (4.2, 6.2)),
        ('Z→Gross IFI', (1.8, 4), (4.2, 4)),
        ('Z→Debt Assets', (1.8, 3.7), (4.2, 1.8)),
        ('Z→CA', (1.8, 4.5), (8.2, 5.5)),
        ('Z→Income Bal', (1.8, 3.5), (8.2, 2.8)),
    ]
    for path_name, (x1, y1), (x2, y2) in arrow_specs:
        if path_name not in paths:
            continue
        coef, p = paths[path_name]
        ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                    arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
        color = 'darkgreen' if p < 0.05 else ('orange' if p < 0.1 else 'red')
        ax.text((x1 + x2) / 2, (y1 + y2) / 2 + 0.3, f'{coef:.1f}{stars(p)}',
                fontsize=9, ha='center', color=color, fontweight='bold')

    ax.set_title('Figure 4: Mediation Paths — Demographics to External Positions',
                 fontsize=13, fontweight='bold', pad=20)

    plt.tight_layout()
    fig.savefig(OUT_FIGURES / "mediation_diagram.png", dpi=150)
    plt.close(fig)
    print("  Saved: mediation_diagram.png")


def main():
    """Load the net/gross panel (years <= 2024) and write the four paper figures.

    Controls are the EBA_CONTROLS present in the panel with more than 200
    non-missing values. Each figure is skipped silently when its inputs are
    unavailable.
    """
    print("=" * 70)
    print("PHASE 9: PAPER FIGURES")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    controls = [c for c in EBA_CONTROLS if c in df.columns and df[c].notna().sum() > 200]
    base_vars = DEMO_VARS + controls

    _plot_forest(df, base_vars)
    _plot_kaopen_gating(df, controls)
    _plot_bilateral_vs_multilateral(df, base_vars)
    _plot_mediation(df, base_vars)

    print("\n" + "=" * 70)
    print("PHASE 9 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
